package com.run2code.idea.plugin.codegenerate.biz.db.utils;

import com.run2code.idea.plugin.codegenerate.biz.db.metadata.Column;
import com.run2code.idea.plugin.codegenerate.biz.db.metadata.Table;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/**
 * JDBC helper that reads MySQL metadata: table names, table/column descriptions
 * and the server version.
 *
 * <p>Each public query method opens its own connection and closes it before
 * returning. Checked {@link SQLException}s are rethrown as {@link RuntimeException}
 * with the original exception attached as the cause.
 *
 * @author 第七人格
 * @date 2023/10/30
 */
public class DBUtils {

    /**
     * JDBC driver class loaded by the constructor.
     */
    private static final String DRIVER_CLASS = "com.mysql.jdbc.Driver";

    /**
     * Query string appended to every JDBC URL built by this class.
     */
    private static final String URL_PARAMS = "?useUnicode=true&characterEncoding=UTF-8&useSSL=false&autoReconnect=true&failOverReadOnly=false&serverTimezone=GMT%2B8";

    /**
     * Table types requested from {@link DatabaseMetaData#getTables}.
     */
    private static final String[] TABLE_TYPES = {"TABLE"};

    /**
     * Database host.
     */
    private final String host;
    /**
     * Database port.
     */
    private final Integer port;
    /**
     * User name.
     */
    private final String username;
    /**
     * Password.
     */
    private final String password;
    /**
     * Name of the database that metadata queries run against. Replaced by
     * {@link #getAllTableName(String)}.
     */
    private String database;

    /**
     * Connection properties handed to {@link DriverManager}.
     */
    private final Properties properties;

    /**
     * Stores the connection settings, loads the JDBC driver and prepares the
     * connection properties. No connection is opened here.
     *
     * @param host     database host
     * @param port     database port
     * @param username user name
     * @param password password
     * @param database database name used by the metadata queries
     * @throws RuntimeException if the driver class cannot be found
     */
    public DBUtils(String host, Integer port, String username, String password, String database) {
        this.host = host;
        this.port = port;
        this.username = username;
        this.password = password;
        this.database = database;

        try {
            Class.forName(DRIVER_CLASS);
        } catch (ClassNotFoundException e) {
            throw new RuntimeException(e.getMessage(), e);
        }

        properties = new Properties();
        properties.setProperty("user", this.username);
        properties.setProperty("password", this.password);
        // NOTE(review): presumably these two make the driver populate the REMARKS
        // metadata column read below - confirm against the driver documentation.
        properties.setProperty("remarks", "true");
        properties.setProperty("useInformationSchema", "true");
    }

    /**
     * Opens a connection to the server without selecting a database.
     * The caller is responsible for closing the returned connection.
     *
     * @return {@link Connection} a new connection
     * @throws RuntimeException if the connection cannot be established
     */
    public Connection getConnection() {
        return openConnection("jdbc:mysql://" + this.host + ":" + this.port + URL_PARAMS);
    }

    /**
     * Opens a connection to the given database.
     *
     * @param database database name
     * @return {@link Connection} a new connection
     * @throws RuntimeException if the connection cannot be established
     */
    private Connection getConnection(String database) {
        return openConnection("jdbc:mysql://" + this.host + ":" + this.port + "/" + database + URL_PARAMS);
    }

    /**
     * Opens a connection for the given JDBC URL using the stored properties.
     *
     * @param url complete JDBC URL
     * @return {@link Connection} a new connection
     * @throws RuntimeException wrapping the {@link SQLException} on failure
     */
    private Connection openConnection(String url) {
        try {
            return DriverManager.getConnection(url, properties);
        } catch (SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Closes the connection, ignoring {@code null} and close failures.
     *
     * @param conn connection to close, may be {@code null}
     */
    private void closeConnection(Connection conn) {
        if (conn != null) {
            try {
                conn.close();
            } catch (SQLException ignored) {
                // Best effort: the query result has already been produced (or its
                // failure is already propagating); a close error must not mask it.
            }
        }
    }

    /**
     * Lists the names of all tables in the given database.
     *
     * <p>Side effect: the given name replaces the database stored in this
     * instance, so later calls to {@link #getTable(String)} and
     * {@link #getResembleTableName(String)} run against it.
     *
     * @param database database name
     * @return {@link List}<{@link String}> table names
     * @throws RuntimeException if the metadata query fails
     */
    public List<String> getAllTableName(String database) {
        this.database = database;
        return listTableNames("%");
    }

    /**
     * Loads the description of a single table, including its columns.
     *
     * @param tableName table name
     * @return {@link Table} the table description, or {@code null} when no such table exists
     * @throws RuntimeException if the metadata query fails
     */
    public Table getTable(String tableName) {
        Connection conn = getConnection(this.database);
        try {
            DatabaseMetaData metaData = conn.getMetaData();
            String remarks = null;
            boolean found = false;
            // Scope the lookup to the configured database, exactly as the table
            // listing methods do; a null catalog would not narrow the search.
            try (ResultSet rs = metaData.getTables(this.database, this.database, tableName, TABLE_TYPES)) {
                while (!found && rs.next()) {
                    // The table name argument is a LIKE pattern ('_' and '%' are
                    // wildcards), so only accept the row of the requested table.
                    if (isRequestedTable(tableName, rs.getString("TABLE_NAME"))) {
                        remarks = rs.getString("REMARKS");
                        found = true;
                    }
                }
            }
            if (!found) {
                return null;
            }
            return new Table(remarks, tableName, getAllColumn(tableName, conn));
        } catch (SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            closeConnection(conn);
        }
    }

    /**
     * Lists the tables of the current database whose name contains the given text.
     *
     * @param tableName text to search for in table names
     * @return {@link List}<{@link String}> matching table names
     * @throws RuntimeException if the metadata query fails
     */
    public List<String> getResembleTableName(String tableName) {
        return listTableNames("%" + tableName + "%");
    }

    /**
     * Collects the names of the tables in the current database that match the pattern.
     *
     * @param tableNamePattern LIKE-style table name pattern
     * @return {@link List}<{@link String}> matching table names
     * @throws RuntimeException if the metadata query fails
     */
    private List<String> listTableNames(String tableNamePattern) {
        Connection conn = getConnection(this.database);
        try {
            DatabaseMetaData metaData = conn.getMetaData();
            List<String> names = new ArrayList<>();
            try (ResultSet rs = metaData.getTables(this.database, this.database, tableNamePattern, TABLE_TYPES)) {
                while (rs.next()) {
                    names.add(rs.getString("TABLE_NAME"));
                }
            }
            return names;
        } catch (SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            closeConnection(conn);
        }
    }

    /**
     * Reads all columns of a table and flags the primary key column.
     *
     * @param tableName table name
     * @param conn      open connection to read metadata from; not closed here
     * @return {@link List}<{@link Column}> columns of the table
     * @throws RuntimeException if the metadata query fails
     */
    private List<Column> getAllColumn(String tableName, Connection conn) {
        try {
            DatabaseMetaData metaData = conn.getMetaData();

            // NOTE(review): for a composite primary key only the last column the
            // driver reports is remembered, so only that column is flagged as id.
            String primaryKey = null;
            try (ResultSet primaryKeys = metaData.getPrimaryKeys(this.database, this.database, tableName)) {
                while (primaryKeys.next()) {
                    primaryKey = primaryKeys.getString("COLUMN_NAME");
                }
            }

            List<Column> columns = new ArrayList<>();
            try (ResultSet rs = metaData.getColumns(this.database, this.database, tableName, null)) {
                while (rs.next()) {
                    // The table name argument is a LIKE pattern; skip columns that
                    // belong to other tables which merely match the pattern.
                    if (!isRequestedTable(tableName, rs.getString("TABLE_NAME"))) {
                        continue;
                    }
                    String columnName = rs.getString("COLUMN_NAME");
                    Column column = new Column(rs.getString("REMARKS"), columnName, rs.getInt("DATA_TYPE"));
                    if (columnName.equals(primaryKey)) {
                        column.setId(true);
                    }
                    columns.add(column);
                }
            }
            return columns;
        } catch (SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        }
    }

    /**
     * Tells whether a metadata row belongs to the requested table. The comparison
     * ignores case; a {@code null} request accepts every row.
     *
     * @param requested table name passed by the caller
     * @param actual    TABLE_NAME value of the metadata row
     * @return {@code true} if the row describes the requested table
     */
    private static boolean isRequestedTable(String requested, String actual) {
        return requested == null || requested.equalsIgnoreCase(actual);
    }

    /**
     * Tests the connection settings by asking the server for its version.
     *
     * @return {@link String} the value of {@code SELECT VERSION()}, or a fixed
     *         fallback text when the query returns no row
     * @throws RuntimeException if connecting or querying fails
     */
    public String testDatabase() {
        Connection conn = getConnection();

        try (PreparedStatement statement = conn.prepareStatement("SELECT VERSION() AS MYSQL_VERSION");
             ResultSet resultSet = statement.executeQuery()) {
            if (resultSet.next()) {
                return resultSet.getString("MYSQL_VERSION");
            }
            return "未知错误";
        } catch (SQLException e) {
            throw new RuntimeException(e.getMessage(), e);
        } finally {
            closeConnection(conn);
        }
    }

}
